{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "z5M7nhtK2_5y"
   },
   "outputs": [],
   "source": [
    "# Enable the autoreload extension in mode 2 (reload imported modules before running code)\n",
    "# and render matplotlib figures inline in the notebook output.\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "yvVx1gpC2_51"
   },
   "source": [
    "If the `econml` and `wget` Python packages are not installed on your machine, install them by running the cells below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "FsvTNisb2_52"
   },
   "outputs": [],
   "source": [
    "# Use the %pip magic (instead of a !pip shell call) so the package is installed into the\n",
    "# environment of the running kernel.\n",
    "%pip install econml"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "pRRPaoDU2_52"
   },
   "outputs": [],
   "source": [
    "# Use the %pip magic (instead of a !pip shell call) so the package is installed into the\n",
    "# environment of the running kernel.\n",
    "%pip install wget"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Asg9JcK52_52"
   },
   "source": [
    "Importing all the necessary components"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "OnDJYhWb2_53"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from statsmodels.api import OLS\n",
    "from sklearn.model_selection import StratifiedKFold\n",
    "from sklearn.model_selection import cross_val_predict\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "rvYTlxo72_53"
   },
   "outputs": [],
   "source": [
    "import wget\n",
    "import os\n",
    "\n",
    "# Helper modules imported later in this notebook (`datasets` and `myxgb`).\n",
    "HELPER_BASE_URL = 'https://raw.githubusercontent.com/CausalAIBook/MetricsMLNotebooks/main/T/'\n",
    "\n",
    "for helper_file in ['datasets.py', 'myxgb.py']:\n",
    "    # delete any existing local copy before downloading a fresh one\n",
    "    if os.path.exists(helper_file):\n",
    "        os.remove(helper_file)\n",
    "    wget.download(HELPER_BASE_URL + helper_file)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "VcCRDcgO2_53"
   },
   "source": [
    "# Setting the High Level Parameters for the Notebook\n",
    "\n",
    "The 401k dataset is downloaded directly from its source by the code, so there is no need to download anything manually:\n",
    "https://raw.githubusercontent.com/CausalAIBook/MetricsMLNotebooks/main/data/401k.csv\n",
    "\n",
    "The welfare dataset is downloaded directly from its source by the code, so there is no need to download anything manually:\n",
    "https://github.com/gsbDBI/ExperimentData/blob/master/Welfare/ProcessedData/welfarenolabel3.csv\n",
    "\n",
    "It is drawn from the analysis in this paper: [Green and Kern, 2012, Modeling Heterogeneous Treatment Effects in Survey Experiments with Bayesian Additive Regression Trees](https://github.com/gsbDBI/ExperimentData/blob/master/Welfare/Green%20and%20Kern%20BART.pdf)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "5UEDqOWd2_54"
   },
   "outputs": [],
   "source": [
    "# Selects which configuration the next cell applies; it defines settings for '401k' and 'welfare'.\n",
    "dataset = 'welfare'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "XMBrj69m2_54"
   },
   "outputs": [],
   "source": [
    "# Settings shared by both dataset configurations\n",
    "verbose = 0  # verbosity of auto-ml\n",
    "n_splits = 5  # cross-fitting and cross-validation splits\n",
    "data = dataset  # name of the dataset passed to fetch_data_generator\n",
    "plot = True  # whether to plot results\n",
    "control_feats = 'all'  # list of subset of features to be used as confounders or the string 'all' for everything\n",
    "# list of subset of control features to be used for CATE model or the string 'all' for all controls\n",
    "hetero_feats = 'all'\n",
    "cov_clip = .01  # clipping of treatment variance p(x)*(1-p(x)), whenever used in inverse propensities\n",
    "random_seed = 1\n",
    "\n",
    "# For semi-synthetic data generation\n",
    "semi_synth = False  # Whether true outcome y should be replaced by a fake outcome from a known CEF\n",
    "simple_synth = True  # Whether the true CEF of the fake y should be simple or fitted from data\n",
    "max_depth = 2  # max depth of random forest used for semi-synthetic model fitting\n",
    "scale = .2  # magnitude of noise in semi-synthetic data\n",
    "\n",
    "# Dataset-specific settings\n",
    "if dataset == '401k':\n",
    "    cfit = False  # whether nuisance models are cross-fitted (otherwise fitted and predicted in-sample)\n",
    "    xfeat = 'inc'  # feature to use as x axis in plotting\n",
    "    # Formulas for the BLP of CATE regressions (the *_short variants are presumably display labels)\n",
    "    blp_formula = 'np.log(inc)'\n",
    "    blp_formula_short = 'log(inc)'\n",
    "    blp_formula2 = 'np.log(inc) + np.power(np.log(inc), 2) + np.power(np.log(inc), 3) + np.power(np.log(inc), 4)'\n",
    "    blp_formula2_short = 'poly(log(inc), 4)'\n",
    "    binary_y = False  # whether the outcome is binary (a classifier is then used as outcome model)\n",
    "\n",
    "    # treatment policy to evaluate\n",
    "    def policy(x):\n",
    "        return x['inc'] > 30000\n",
    "\n",
    "    # cost of treatment when performing optimal policy learning, can also be viewed as \"threshold for treatment\"\n",
    "    treatment_cost = 4000\n",
    "\n",
    "elif dataset == 'welfare':\n",
    "    cfit = True  # whether nuisance models are cross-fitted (otherwise fitted and predicted in-sample)\n",
    "    xfeat = 'polviews'  # feature to use as x axis in plotting\n",
    "    # Formulas for the BLP of CATE regressions (the *_short variants are presumably display labels)\n",
    "    blp_formula = 'C(polviews)'\n",
    "    blp_formula_short = 'C(polviews)'\n",
    "    blp_formula2 = 'polviews'\n",
    "    blp_formula2_short = 'polviews'\n",
    "    binary_y = True  # whether the outcome is binary (a classifier is then used as outcome model)\n",
    "\n",
    "    # treatment policy to evaluate\n",
    "    def policy(x):\n",
    "        return x['polviews'] > 0\n",
    "\n",
    "    # cost of treatment when performing optimal policy learning, can also be viewed as \"threshold for treatment\"\n",
    "    treatment_cost = -.3\n",
    "\n",
    "else:\n",
    "    # fail fast instead of leaving the settings undefined and hitting a NameError further down\n",
    "    raise ValueError(f\"No configuration defined for dataset={dataset!r}; expected '401k' or 'welfare'\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def simple_true_cef(D, X):\n",
    "    \"\"\"Simple CEF of the outcome for semi-synthetic data.\n",
    "\n",
    "    Returns ``.5 * x1 * D + x1``, where ``x1`` is the column of ``X`` at position 1.\n",
    "    \"\"\"\n",
    "    x1 = np.array(X)[:, 1]\n",
    "    return .5 * x1 * D + x1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "VWkRUz3I2_54"
   },
   "source": [
    "# Fetching and Constructing the Dataset\n",
    "\n",
    "The data generator also allows for semi-synthetic data generation where the ground truth CATE is known, which can be used for evaluation of different methods."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "F6RgjgO-2_54"
   },
   "outputs": [],
   "source": [
    "# `datasets` is the helper module (datasets.py) downloaded earlier in this notebook.\n",
    "from datasets import fetch_data_generator\n",
    "\n",
    "# Build the data generator from the notebook settings; `simple_true_cef` is passed as `true_f`.\n",
    "# NOTE(review): presumably `true_cef`/`true_cate` are only meaningful when semi_synth=True - confirm in datasets.py.\n",
    "get_data, abtest, true_cef, true_cate = fetch_data_generator(data=data, semi_synth=semi_synth,\n",
    "                                                             simple_synth=simple_synth,\n",
    "                                                             scale=scale, true_f=simple_true_cef,\n",
    "                                                             max_depth=max_depth)\n",
    "# controls, treatment, outcome and group (cluster) ids\n",
    "X, D, y, groups = get_data()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "gORrK-bj2_55"
   },
   "outputs": [],
   "source": [
    "# With semi-synthetic data the CATE function is known, so report its sample average as the true ATE.\n",
    "if semi_synth:\n",
    "    true_ate = np.mean(true_cate(X))\n",
    "    print(f'True ATE: {true_ate}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "cX4v7fQ82_55"
   },
   "outputs": [],
   "source": [
    "def rmse(cate, preds):\n",
    "    \"\"\"Root mean squared error between the values `cate` and the predictions `preds`.\"\"\"\n",
    "    squared_errors = (cate - preds)**2\n",
    "    return np.sqrt(np.mean(squared_errors))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "bhaPwyWJ2_55"
   },
   "source": [
    "# Data Analysis\n",
    "\n",
    "We now have our data $X$, $D$, $y$, of controls, treatments and outcomes. In some datasets, we also have \"groups\", also known as \"clusters\". These are group ids that define a group of samples that are believed to be correlated through unobserved factors. For instance, in randomized experiments when a whole class is being treated and we have data at the student level, the students in a class constitute a cluster, as their outcome variables are most probably correlated. In such settings, it is helpful to account for the cluster correlations when calculating confidence intervals and when performing sample splitting for either cross-validation or for nuisance estimation. This notebook will not deal with such group clusters. A more tailored analysis is required.\n",
    "\n",
    "We will be assuming throughout that conditional ignorability is satisfied if we control for all the variables $X$, i.e. the potential outcomes $Y(1), Y(0)$ satisfy\n",
    "\\begin{align}\n",
    "Y(1), Y(0) ~\\perp\\hspace{-1em}\\perp~D \\mid X\n",
    "\\end{align}\n",
    "Equivalently, we assume that in the DAG that corresponds to our setting, $X$ is a valid adjustment set between $D$ and $Y$, i.e. it blocks all backdoor paths in the DAG."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "PwBWUvbB2_55"
   },
   "outputs": [],
   "source": [
    "# Summary statistics of the control features as returned by the data generator.\n",
    "X.describe()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "YyMxyYwKXKfN"
   },
   "outputs": [],
   "source": [
    "# Optionally restrict the controls to the configured subset of columns and summarize them again.\n",
    "if control_feats != 'all':\n",
    "    X = X[control_feats]\n",
    "    print(X.describe())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "HzQ-w_uMYvu0"
   },
   "outputs": [],
   "source": [
    "# Center every control column at its sample mean.\n",
    "# NOTE(review): after centering, features take negative values, so formulas such as 'np.log(inc)'\n",
    "# (401k configuration) would be evaluated on centered data - confirm this is intended.\n",
    "X = X - X.mean(axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ysDh_Cea2_55"
   },
   "outputs": [],
   "source": [
    "# Side-by-side histograms of the treatment and the outcome.\n",
    "fig, (ax_treatment, ax_outcome) = plt.subplots(1, 2, figsize=(15, 5))\n",
    "ax_treatment.hist(D)\n",
    "ax_treatment.set_title('Treatment')\n",
    "ax_outcome.hist(y)\n",
    "ax_outcome.set_title('Outcome')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "yvczDYFd2_55"
   },
   "outputs": [],
   "source": [
    "# Simple two-means estimate (OLS of y on a constant and D);\n",
    "# it would be wrong unless the data come from a randomized trial.\n",
    "intercept_column = np.ones((D.shape[0], 1))\n",
    "two_means_design = np.hstack([intercept_column, D.reshape(-1, 1)])\n",
    "OLS(y, two_means_design).fit(cov_type='HC1').summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "pHmIX_Ft2_55"
   },
   "source": [
    "# Nuisance Cross-Fitted Estimation and Prediction"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "tr-8lvK22_55"
   },
   "source": [
    "We will estimate regression models for each of the nuisance functions that arise in CATE learning approaches. The four models correspond to the following four predictive problems:\n",
    "\\begin{align}\n",
    "\\text{model_y} ~\\rightarrow~& q(x) := E[Y\\mid X=x]\\\\\n",
    "\\text{model_t} ~\\rightarrow~& p(x) := E[D\\mid X=x] = \\Pr(D=1\\mid X=x)\\\\\n",
    "\\text{model_reg_zero} ~\\rightarrow~& g_0(x) := E[Y\\mid D=0, X=x]\\\\\n",
    "\\text{model_reg_one} ~\\rightarrow~& g_1(x) := E[Y\\mid D=1, X=x]\\\\\n",
    "\\end{align}\n",
    "We will use gradient boosting regression with early stopping for each of these models."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "YKEollbs2_55"
   },
   "source": [
    "For each of the nuisance models we perform cross-fitting to get out-of-fold predictions from each of these nuisance models. At the end of this process, we will have for each sample $i$, the following out-of-fold nuisance values:\n",
    "\\begin{align}\n",
    "\\text{reg_preds_t} \\rightarrow~& \\hat{g}_0(X_i) (1 - D_i) + \\hat{g}_1(X_i) D_i &\n",
    "\\text{reg_one_preds_t} \\rightarrow~& \\hat{g}_1(X_i) &\n",
    "\\text{reg_zero_preds_t} \\rightarrow~& \\hat{g}_0(X_i)\\\\\n",
    "\\text{res_preds} \\rightarrow~& \\hat{q}(X_i) &\n",
    "\\text{prop_preds} \\rightarrow~& \\hat{p}(X_i)\n",
    "\\end{align}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "3rSVD8gO2_56"
   },
   "outputs": [],
   "source": [
    "from myxgb import xgb_reg, xgb_clf, RegWrapper\n",
    "\n",
    "\n",
    "def auto_reg():\n",
    "    \"\"\"Build a regressor via `xgb_reg`, seeded with the notebook's `random_seed`.\"\"\"\n",
    "    return xgb_reg(random_seed)\n",
    "\n",
    "\n",
    "def auto_clf():\n",
    "    \"\"\"Build a classifier via `xgb_clf` (seeded with `random_seed`), wrapped in `RegWrapper`.\n",
    "\n",
    "    Disclaimer: the remainder of the code assumes that `auto_clf().predict(X)` returns the\n",
    "    probability of class 1 and not the 0/1 classification. This is what the\n",
    "    RegWrapper(xgb_clf()) class does.\n",
    "    \"\"\"\n",
    "    return RegWrapper(xgb_clf(random_seed))\n",
    "\n",
    "\n",
    "# Outcome model factory: the classifier for binary outcomes, otherwise the regressor.\n",
    "if binary_y:\n",
    "    modely = auto_clf\n",
    "else:\n",
    "    modely = auto_reg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "GHIzWx_D2_56"
   },
   "outputs": [],
   "source": [
    "n = X.shape[0]\n",
    "\n",
    "# Sample splits: stratified shuffled K folds when cross-fitting, otherwise a single\n",
    "# \"split\" that trains and predicts on all samples.\n",
    "if cfit:\n",
    "    cv = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_seed)\n",
    "    # stratify on the (treatment, outcome) combination for binary outcomes, else on treatment only\n",
    "    stratification_label = 2 * D + y if binary_y else D\n",
    "    splits = list(cv.split(X, stratification_label))\n",
    "else:\n",
    "    splits = [(np.arange(n), np.arange(n))]\n",
    "\n",
    "reg_preds_t, reg_zero_preds_t, reg_one_preds_t = np.zeros(n), np.zeros(n), np.zeros(n)\n",
    "\n",
    "for train, test in splits:\n",
    "    X_train, y_train, D_train = X.iloc[train], y[train], D[train]\n",
    "    is_control, is_treated = D_train == 0, D_train == 1\n",
    "    # separate outcome models for the control and the treated samples of the training fold\n",
    "    reg_zero = modely().fit(X_train[is_control], y_train[is_control])\n",
    "    reg_one = modely().fit(X_train[is_treated], y_train[is_treated])\n",
    "    X_test, D_test = X.iloc[test], D[test]\n",
    "    reg_zero_preds_t[test] = reg_zero.predict(X_test)\n",
    "    reg_one_preds_t[test] = reg_one.predict(X_test)\n",
    "    # prediction of the model that matches each sample's observed treatment\n",
    "    reg_preds_t[test] = reg_zero_preds_t[test] * (1 - D_test) + reg_one_preds_t[test] * D_test\n",
    "\n",
    "# predictions of y given X and of D given X on the same splits\n",
    "res_preds = cross_val_predict(modely(), X, y, cv=splits)\n",
    "prop_preds = cross_val_predict(auto_clf(), X, D, cv=splits)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "5SnpS8QS3bMz"
   },
   "source": [
    "# Evaluating Nuisance Model Performance\n",
    "\n",
    "We now also evaluate the performance of the selected models in terms of R^2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "IYqGlo-Y3XRH"
   },
   "outputs": [],
   "source": [
    "def r2score(y, ypred):\n",
    "    \"\"\"R^2 of the predictions: one minus the mean squared error divided by the variance of `y`.\"\"\"\n",
    "    mse = np.mean((y - ypred)**2)\n",
    "    return 1 - mse / np.var(y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "NSi-SFXa3Vk3"
   },
   "outputs": [],
   "source": [
    "# Score each nuisance model's predictions first, then report them to four decimals.\n",
    "r2_y = r2score(y, res_preds)\n",
    "r2_d = r2score(D, prop_preds)\n",
    "r2_y_control = r2score(y[D == 0], reg_zero_preds_t[D == 0])\n",
    "r2_y_treated = r2score(y[D == 1], reg_one_preds_t[D == 1])\n",
    "r2_y_joint = r2score(y, reg_preds_t)\n",
    "print(f\"R^2 of model for (y ~ X): {r2_y:.4f}\")\n",
    "print(f\"R^2 of model for (D ~ X): {r2_d:.4f}\")\n",
    "print(f\"R^2 of model for (y ~ X | D==0): {r2_y_control:.4f}\")\n",
    "print(f\"R^2 of model for (y ~ X | D==1): {r2_y_treated:.4f}\")\n",
    "print(f\"R^2 of model for (y ~ D, X): {r2_y_joint:.4f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "2EytRrmS2_56"
   },
   "source": [
    "# Doubly-Robust ATE Estimation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "XGDHhMSH2_56"
   },
   "source": [
    "Using the doubly robust method. In particular, we construct the doubly robust variables\n",
    "\\begin{align}\n",
    "Y_i^{DR}(\\hat{g},\\hat{p}) := \\hat{g}_1(X_i) - \\hat{g}_0(X_i) + (Y_i - \\hat{g}_{D_i}(X_i))\\frac{D_i - \\hat{p}(X_i)}{\\hat{p}(X_i) (1-\\hat{p}(X_i))}\n",
    "\\end{align}\n",
    "and then we estimate:\n",
    "\\begin{align}\n",
    "ATE = E_n\\left[Y^{DR}(\\hat{g},\\hat{p})\\right]\n",
    "\\end{align}\n",
    "This should be more efficient in the worst-case and should be returning a consistent estimate of the ATE even beyond RCTs and will also correctly account for any imbalances or violations of the randomization assumption in an RCT."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "syeQPPmL2_56"
   },
   "outputs": [],
   "source": [
    "# Doubly robust pseudo-outcome: regression contrast plus an inverse-propensity-weighted\n",
    "# residual correction; the denominator p(x)*(1-p(x)) is clipped from below at cov_clip.\n",
    "dr_denominator = np.clip(prop_preds * (1 - prop_preds), cov_clip, np.inf)\n",
    "dr_correction = (y - reg_preds_t) * (D - prop_preds) / dr_denominator\n",
    "dr_preds = reg_one_preds_t - reg_zero_preds_t\n",
    "dr_preds += dr_correction\n",
    "\n",
    "# ATE estimate: regression of the pseudo-outcome on a constant, with HC1 standard errors\n",
    "ate_design = np.ones((len(dr_preds), 1))\n",
    "OLS(dr_preds, ate_design).fit(cov_type='HC1').summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Ec3Ul72W2_56"
   },
   "source": [
    "# Best Linear CATE Predictor"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Jd-GJWER2_56"
   },
   "source": [
    "We can also use the doubly robust variables as pseudo-outcomes in an OLS regression, so as to estimate the best linear approximation of the true CATE. In an RCT, these should be similar to the coefficients recovered in a plain interactive OLS regression."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "5MC7zEvX2_56"
   },
   "outputs": [],
   "source": [
    "# Regress the doubly robust pseudo-outcome on all (centered) controls plus a constant column.\n",
    "dfX = X.assign(const=1)\n",
    "lr = OLS(dr_preds, dfX).fit(cov_type='HC1')\n",
    "# results object with the HC1 covariance; its params and cov_params() are used by later cells\n",
    "cov = lr.get_robustcov_results(cov_type='HC1')\n",
    "lr.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "WCJu-ajr2_56"
   },
   "source": [
    "# Simultaneous (Joint) Confidence Intervals\n",
    "We can also perform joint inference on all these parameters, constructing confidence intervals that cover all the parameters simultaneously with probability 95%, i.e. the joint probability of failure is at most 5%."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "unfUkJFk2_56"
   },
   "outputs": [],
   "source": [
    "# Simultaneous confidence intervals: simulate the maximum absolute value of a Gaussian vector\n",
    "# whose covariance is the standardized (unit-diagonal) version of V, and use the 95th percentile\n",
    "# of that maximum as the critical value.\n",
    "V = cov.cov_params()\n",
    "S = np.diag(np.diagonal(V)**(-1 / 2))\n",
    "# seeded generator so that the simulated critical value is reproducible across runs\n",
    "joint_ci_rng = np.random.default_rng(random_seed)\n",
    "epsilon = joint_ci_rng.multivariate_normal(np.zeros(V.shape[0]), S @ V @ S, size=1000)\n",
    "critical = np.percentile(np.max(np.abs(epsilon), axis=1), 95)\n",
    "stderr = np.diagonal(V)**(1 / 2)\n",
    "lb = cov.params - critical * stderr\n",
    "ub = cov.params + critical * stderr\n",
    "# mark coefficients whose simultaneous interval excludes zero\n",
    "statsig = ['' if lower <= 0 <= upper else '**' for lower, upper in zip(lb, ub)]\n",
    "jointsummary = pd.DataFrame({'coef': cov.params,\n",
    "                             'std err': stderr,\n",
    "                             'lb': lb,\n",
    "                             'ub': ub,\n",
    "                             'statsig': statsig},\n",
    "                            index=dfX.columns)\n",
    "jointsummary"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "UWNdzosR2_57"
   },
   "source": [
    "# Confidence Intervals on BLP of CATE Predictions\n",
    "\n",
    "We can also produce confidence intervals for the predictions of the CATE at particular points"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "CTRIRfei2_57"
   },
   "outputs": [],
   "source": [
    "# Evaluation points: `xfeat` at its 0th, 20th, ..., 100th percentiles (duplicates removed),\n",
    "# every other column of dfX held at its median.\n",
    "grid = np.unique(np.percentile(dfX[xfeat], np.arange(0, 110, 20)))\n",
    "median_row = np.median(dfX, axis=0, keepdims=True)\n",
    "Zpd = pd.DataFrame(np.tile(median_row, (len(grid), 1)), columns=dfX.columns)\n",
    "Zpd[xfeat] = grid\n",
    "\n",
    "# predicted mean and its confidence interval at each evaluation point\n",
    "pred_df = lr.get_prediction(Zpd).summary_frame()\n",
    "preds = pred_df['mean'].values.flatten()\n",
    "lb = pred_df['mean_ci_lower'].values.flatten()\n",
    "ub = pred_df['mean_ci_upper'].values.flatten()\n",
    "plt.errorbar(Zpd[xfeat], preds, yerr=(preds - lb, ub - preds))\n",
    "plt.xlabel(xfeat)\n",
    "plt.ylabel('Predicted CATE (at median value of other features)')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "jmXVt9xT2_57"
   },
   "source": [
    "# Simultaneous (Joint) Confidence Intervals on BLP of CATE Predictions\n",
    "\n",
    "We can even perform simultaneous inference on all these predictions, controlling the joint failure probability of these confidence intervals to be at most 5% (i.e. joint coverage of at least 95%)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "YDmeaAsi2_57"
   },
   "outputs": [],
   "source": [
    "# Covariance of the linear predictions at the evaluation points, followed by a simulation of the\n",
    "# maximum absolute value of the standardized Gaussian vector to get a simultaneous critical value.\n",
    "predsV = Zpd.values @ V @ Zpd.values.T\n",
    "predsS = np.diag(np.diagonal(predsV)**(-1 / 2))\n",
    "# seeded generator so that the simulated critical value is reproducible across runs\n",
    "joint_pred_rng = np.random.default_rng(random_seed)\n",
    "epsilon = joint_pred_rng.multivariate_normal(np.zeros(predsV.shape[0]), predsS @ predsV @ predsS, size=1000)\n",
    "critical = np.percentile(np.max(np.abs(epsilon), axis=1), 95)\n",
    "stderr = np.diagonal(predsV)**(1 / 2)\n",
    "lb = preds - critical * stderr\n",
    "ub = preds + critical * stderr\n",
    "\n",
    "plt.errorbar(Zpd[xfeat], preds, yerr=(preds - lb, ub - preds))\n",
    "plt.xlabel(xfeat)\n",
    "plt.ylabel('Predicted CATE (at median value of other features)')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "BeKRN-4-2_6D"
   },
   "source": [
    "# Simpler Best Linear Projections of CATE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "HLcRxLjI2_6E"
   },
   "outputs": [],
   "source": [
    "from statsmodels.formula.api import ols\n",
    "# Controls plus the doubly robust pseudo-outcome as column 'dr', then the formula-based BLP fit.\n",
    "df = X.assign(dr=dr_preds)\n",
    "lr = ols('dr ~ ' + blp_formula, df).fit(cov_type='HC1')\n",
    "lr.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "LH6xA_VC2_6E"
   },
   "outputs": [],
   "source": [
    "# Prediction grid: `xfeat` sweeps its even percentiles (0, 2, ..., 100; duplicates removed)\n",
    "# while every other feature is held at its median.\n",
    "grid = np.unique(np.percentile(X[xfeat], np.arange(0, 102, 2)))\n",
    "median_row = np.median(X, axis=0, keepdims=True)\n",
    "Xpd = pd.DataFrame(np.tile(median_row, (len(grid), 1)), columns=X.columns)\n",
    "Xpd[xfeat] = grid\n",
    "# alpha=.1 is passed to summary_frame for the interval columns\n",
    "pred_df = lr.get_prediction(Xpd).summary_frame(alpha=.1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "F4o22epR2_6E"
   },
   "outputs": [],
   "source": [
    "# Predicted CATE along `xfeat`, with the band between the lower and upper interval columns of pred_df.\n",
    "plt.plot(Xpd[xfeat], pred_df['mean'])\n",
    "plt.fill_between(Xpd[xfeat], pred_df['mean_ci_lower'], pred_df['mean_ci_upper'], alpha=.4)\n",
    "plt.xlabel(xfeat + ' (other features fixed at median value)')\n",
    "plt.title('Predicted CATE BLP: cate ~' + blp_formula)\n",
    "plt.ylabel('CATE')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ir12CLlM2_6E"
   },
   "outputs": [],
   "source": [
    "from statsmodels.formula.api import ols\n",
    "# Same regression data as before (controls plus pseudo-outcome 'dr'), fitted with the second formula.\n",
    "df = X.assign(dr=dr_preds)\n",
    "lr = ols('dr ~ ' + blp_formula2, df).fit(cov_type='HC1')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "AEHVWKe22_6E"
   },
   "outputs": [],
   "source": [
    "# Prediction grid: `xfeat` sweeps its even percentiles (0, 2, ..., 100; duplicates removed)\n",
    "# while every other feature is held at its median.\n",
    "grid = np.unique(np.percentile(X[xfeat], np.arange(0, 102, 2)))\n",
    "median_row = np.median(X, axis=0, keepdims=True)\n",
    "Xpd = pd.DataFrame(np.tile(median_row, (len(grid), 1)), columns=X.columns)\n",
    "Xpd[xfeat] = grid\n",
    "# alpha=.1 is passed to summary_frame for the interval columns\n",
    "pred_df2 = lr.get_prediction(Xpd).summary_frame(alpha=.1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "XOWJEJTT2_6E"
   },
   "outputs": [],
   "source": [
    "# Predicted CATE along `xfeat` for the second formula, with the band between the interval columns of pred_df2.\n",
    "plt.plot(Xpd[xfeat], pred_df2['mean'])\n",
    "plt.fill_between(Xpd[xfeat], pred_df2['mean_ci_lower'], pred_df2['mean_ci_upper'], alpha=.4)\n",
    "plt.xlabel(xfeat + ' (other features fixed at median value)')\n",
    "plt.ylabel('CATE')\n",
    "plt.title('cate ~' + blp_formula2)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "U0H-C_j52_6E"
   },
   "source": [
    "# Non-Parametric Confidence Intervals on CATE Predictions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "TRnIbW5x2_6E"
   },
   "source": [
    "We now move on to the subject of constructing confidence intervals for the predictions of CATE models. Constructing confidence intervals for CATE predictions is an inherently harder task. In full generality, it is at least as hard as constructing confidence intervals for the predictions of a non-parametric regression function, which is a statistically daunting task.\n",
    "\n",
    "We will use data-adaptive approaches like random forests to sidestep the curse of dimensionality and potentially adapt to sparsity in the regression function (though theoretically such adaptivity is impossible in the worst case, it tends to work well in practice). This is the approach taken by CausalForests or Doubly Robust Forests, which are both based on the idea of Generalized Random Forests, an extension of classical forests for solving problems defined via conditional moment restrictions."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "L4Kri6NH2_6E"
   },
   "source": [
    "# Non-Parametric Confidence Intervals with Causal Forests"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "YPl6vc922_6E"
   },
   "outputs": [],
   "source": [
    "# Features the CATE models may use: all controls, or the configured subset.\n",
    "Z = X if hetero_feats == 'all' else X[hetero_feats]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "iI3dlYGk2_6E"
   },
   "outputs": [],
   "source": [
    "# Larger leaves and smaller subsamples when there are more than a million samples.\n",
    "is_large_sample = Z.shape[0] > 1e6\n",
    "min_samples_leaf, max_samples = (500, 0.05) if is_large_sample else (50, .4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "8UXA0krD2_6E"
   },
   "outputs": [],
   "source": [
    "from econml.grf import CausalForest\n",
    "\n",
    "# Residualize outcome and treatment with the nuisance predictions, then fit the causal forest\n",
    "# on the residuals.\n",
    "yres, Dres = y - res_preds, D - prop_preds\n",
    "cf = CausalForest(n_estimators=4000,\n",
    "                  criterion='het',\n",
    "                  max_depth=None,\n",
    "                  max_samples=max_samples,\n",
    "                  min_samples_leaf=min_samples_leaf,\n",
    "                  min_weight_fraction_leaf=.0,\n",
    "                  random_state=random_seed)\n",
    "cf.fit(Z, Dres, yres)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "55jW99F12_6E"
   },
   "outputs": [],
   "source": [
    "# Column index of the feature with the largest causal-forest importance (last in ascending argsort order).\n",
    "top_feat = np.argsort(cf.feature_importances_)[-1]\n",
    "print(Z.columns[top_feat])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "nwsoM6xt2_6F"
   },
   "outputs": [],
   "source": [
    "# Sweep the most important feature over its percentiles (every 5%, duplicates removed),\n",
    "# holding all other features at their medians.\n",
    "grid = np.unique(np.percentile(Z.iloc[:, top_feat], np.arange(0, 105, 5)))\n",
    "median_row = np.median(Z, axis=0, keepdims=True)\n",
    "Zpd = pd.DataFrame(np.tile(median_row, (len(grid), 1)), columns=Z.columns)\n",
    "Zpd.iloc[:, top_feat] = grid\n",
    "\n",
    "# point predictions and interval bounds (alpha=.1), each flattened to a 1-d array\n",
    "preds, lb, ub = (arr.flatten() for arr in cf.predict(Zpd, interval=True, alpha=.1))\n",
    "plt.errorbar(Zpd.iloc[:, top_feat], preds, yerr=(preds - lb, ub - preds))\n",
    "plt.xlabel(Zpd.columns[top_feat])\n",
    "plt.ylabel('Predicted CATE (at median value of other features)')\n",
    "plt.savefig(f'{data}-causal-forest.png', dpi=600)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "fdV2venN2_6F"
   },
   "outputs": [],
   "source": [
    "if semi_synth:\n",
    "    # Compare the forest predictions and intervals with the known CATE, ordered by the true CATE.\n",
    "    true_proj = true_cate(X)\n",
    "    preds, lb, ub = (arr.flatten() for arr in cf.predict(Z, interval=True, alpha=.1))\n",
    "    inds = np.argsort(true_proj)\n",
    "    plt.plot(true_proj[inds], preds[inds])\n",
    "    plt.fill_between(true_proj[inds], lb[inds], ub[inds], alpha=.4)\n",
    "    # 45-degree reference line spanning the range of the true CATE\n",
    "    diagonal = np.linspace(np.min(true_proj), np.max(true_proj), 100)\n",
    "    plt.plot(diagonal, diagonal)\n",
    "    plt.xlabel('True CATE')\n",
    "    plt.ylabel('Predicted CATE')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "EDsjdk_52_6F"
   },
   "outputs": [],
   "source": [
    "# Feature names ordered by decreasing causal-forest importance; display the first four.\n",
    "important_feats = Z.columns[np.argsort(cf.feature_importances_)[::-1]]\n",
    "important_feats[:4]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "8tuWZEX12_6F"
   },
   "outputs": [],
   "source": [
    "# 2x2 panel of causal-forest predictions along each of the four most important features,\n",
    "# with all other features held at their medians.\n",
    "plt.figure(figsize=(10, 7))\n",
    "median_row = np.median(Z, axis=0, keepdims=True)\n",
    "for panel, feature in enumerate(important_feats[:4], start=1):\n",
    "    plt.subplot(2, 2, panel)\n",
    "    grid = np.unique(np.percentile(Z[feature], np.arange(0, 105, 5)))\n",
    "    Zpd = pd.DataFrame(np.tile(median_row, (len(grid), 1)), columns=Z.columns)\n",
    "    Zpd[feature] = grid\n",
    "\n",
    "    # point predictions and interval bounds (alpha=.1), each flattened to a 1-d array\n",
    "    preds, lb, ub = (arr.flatten() for arr in cf.predict(Zpd, interval=True, alpha=.1))\n",
    "    plt.errorbar(Zpd[feature], preds, yerr=(preds - lb, ub - preds))\n",
    "    plt.xlabel(feature)\n",
    "    plt.ylabel('Predicted CATE')\n",
    "plt.tight_layout()\n",
    "plt.savefig(f'{data}-cf-marginal-plots.png', dpi=600)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "owgItlJv2_6F"
   },
   "source": [
    "# Non-Parametric Confidence Intervals with Doubly Robust Forests"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "GqwRym_E2_6F"
   },
   "source": [
    "(standard errors here ignore cluster/group correlations)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "C1pZAGCB2_6F"
   },
   "outputs": [],
   "source": [
    "from econml.grf import RegressionForest\n",
    "\n",
    "# Regression forest fitted on the doubly robust pseudo-outcomes.\n",
    "drrf = RegressionForest(n_estimators=4000,\n",
    "                        max_depth=5,\n",
    "                        max_samples=max_samples,\n",
    "                        min_samples_leaf=min_samples_leaf,\n",
    "                        min_weight_fraction_leaf=.0,\n",
    "                        random_state=random_seed)\n",
    "drrf.fit(Z, dr_preds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "3OxZFkQ82_6F"
   },
   "outputs": [],
   "source": [
    "# Column index of the feature with the largest doubly-robust-forest importance (last in ascending argsort order).\n",
    "top_feat = np.argsort(drrf.feature_importances_)[-1]\n",
    "print(Z.columns[top_feat])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "K4hZOq3Z2_6F"
   },
   "outputs": [],
   "source": [
    "# Sweep the most important feature over its percentiles (every 5%, duplicates removed),\n",
    "# holding all other features at their medians.\n",
    "grid = np.unique(np.percentile(Z.iloc[:, top_feat], np.arange(0, 105, 5)))\n",
    "median_row = np.median(Z.values, axis=0, keepdims=True)\n",
    "Zpd = pd.DataFrame(np.tile(median_row, (len(grid), 1)), columns=Z.columns)\n",
    "Zpd.iloc[:, top_feat] = grid\n",
    "\n",
    "# point predictions and interval bounds (alpha=.1), each flattened to a 1-d array\n",
    "preds, lb, ub = (arr.flatten() for arr in drrf.predict(Zpd, interval=True, alpha=.1))\n",
    "plt.errorbar(Zpd.iloc[:, top_feat], preds, yerr=(preds - lb, ub - preds))\n",
    "plt.xlabel(Zpd.columns[top_feat])\n",
    "plt.ylabel('Predicted CATE (at median value of other features)')\n",
    "plt.savefig(f'{data}-doubly-robust-forest.png', dpi=600)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "yn_dmwD12_6F"
   },
   "outputs": [],
   "source": [
    "important_feats = Z.columns[np.argsort(drrf.feature_importances_)[::-1]]\n",
    "important_feats[:4]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "eVw-mkQY2_6F"
   },
   "outputs": [],
   "source": [
    "plt.figure(figsize=(10, 7))\n",
    "for it, feature in enumerate(important_feats[:4]):\n",
    "    plt.subplot(2, 2, it + 1)\n",
    "    grid = np.unique(np.percentile(Z[feature], np.arange(0, 105, 5)))\n",
    "    Zpd = pd.DataFrame(np.tile(np.median(Z, axis=0, keepdims=True), (len(grid), 1)),\n",
    "                       columns=Z.columns)\n",
    "    Zpd[feature] = grid\n",
    "\n",
    "    preds, lb, ub = drrf.predict(Zpd, interval=True, alpha=.1)\n",
    "    preds = preds.flatten()\n",
    "    lb = lb.flatten()\n",
    "    ub = ub.flatten()\n",
    "    plt.errorbar(Zpd[feature], preds, yerr=(preds - lb, ub - preds))\n",
    "    plt.xlabel(feature)\n",
    "    plt.ylabel('Predicted CATE')\n",
    "plt.tight_layout()\n",
    "plt.savefig(f'{data}-drrf-marginal-plots.png', dpi=600)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "mAjrHKfA2_6F"
   },
   "outputs": [],
   "source": [
    "if semi_synth:\n",
    "    true_proj = true_cate(X)\n",
    "    preds, lb, ub = drrf.predict(Z, interval=True, alpha=.1)\n",
    "    preds = preds.flatten()\n",
    "    lb = lb.flatten()\n",
    "    ub = ub.flatten()\n",
    "    inds = np.argsort(true_proj)\n",
    "    plt.plot(true_proj[inds], preds[inds])\n",
    "    plt.fill_between(true_proj[inds], lb[inds].flatten(), ub[inds].flatten(), alpha=.4)\n",
    "    plt.plot(np.linspace(np.min(true_proj), np.max(true_proj), 100),\n",
    "             np.linspace(np.min(true_proj), np.max(true_proj), 100))\n",
    "    plt.xlabel('True CATE')\n",
    "    plt.ylabel('Predicted CATE')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "cfnCsNVD2_6F"
   },
   "source": [
    "# Policy Evaluation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "kt3-0rq12_6G"
   },
   "source": [
    "Suppose our goal is to estimate the best treatment policy $\\pi: Z \\to \\{0, 1\\}$. The policy gains over no treatment for any policy $\\pi$ can be identified as:\n",
    "\\begin{align}\n",
    "V(\\pi) := E[\\pi(Z)\\, (Y(1) - Y(0))] = E\\left[\\pi(Z)\\, Y^{DR}(g,p)\\right]\n",
    "\\end{align}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "6ier8uoA2_6G"
   },
   "source": [
    "#### Evaluating some personalized policy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "udvQSs_E2_6G"
   },
   "outputs": [],
   "source": [
    "pi = (dr_preds - treatment_cost) * policy(Z)\n",
    "point = np.mean(pi)\n",
    "stderr = np.sqrt(np.var(pi) / pi.shape[0])\n",
    "print(f\"{point:.5f}, {stderr:.5f}, {point - 1.96 * stderr:.5f}, {point + 1.96 * stderr:.5f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "zo7To0gD2_6G"
   },
   "source": [
    "#### As compared to treating everyone"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "xeNoEkin2_6G"
   },
   "outputs": [],
   "source": [
    "pi = (dr_preds - treatment_cost)\n",
    "point = np.mean(pi)\n",
    "stderr = np.sqrt(np.var(pi) / pi.shape[0])\n",
    "print(f\"{point:.5f}, {stderr:.5f}, {point - 1.96 * stderr:.5f}, {point + 1.96 * stderr:.5f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "N1MMjO-uxj1p"
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "A100",
   "machine_shape": "hm",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
